Writing a custom activation function

Activation functions (also called transforms in neon) are nonlinearities such as rectified linear, softmax, or hyperbolic tangent functions.

Generally these functions are bundled together with a neural network layer such as Affine or Convolution by passing them as the layer's activation argument (see the sketch below).
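For example, a built-in transform can be attached to a layer like this. This is a minimal sketch for illustration only: the layer size and initializer are arbitrary placeholders, and a backend must be generated (as in the next cell) before the network can actually run.

In [ ]:
from neon.initializers import Gaussian
from neon.layers import Affine
from neon.transforms import Rectlin

# a fully-connected layer whose output is passed through a
# rectified linear activation
layer = Affine(nout=100, init=Gaussian(scale=0.01), activation=Rectlin())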

We subclass the Transform interface in neon, which specifies __call__ and bprop methods. __call__ is used during the forward pass computation, and bprop is used to compute the derivative during the backward pass.


In [ ]:
from neon.backends import gen_backend
from neon.transforms.transform import Transform
be = gen_backend('gpu', batch_size=128)  # use 'cpu' here if no compatible GPU is available

class MySoftmax(Transform):
    """
    SoftMax activation function. Ensures that the activation output sums to 1.
    """
    def __init__(self, name=None):
        super(MySoftmax, self).__init__(name)

    def __call__(self, x):
        """
        Implement softmax(x) = e^(x - max(x)) / sum(e^(x - max(x))).

        Input x has shape (# features, batch_size).

        Return softmax(x), with shape (# features, batch_size), where
        the features in each column sum to 1.
        """
        # subtracting max(x) does not change the result (the e^(-max(x))
        # factors cancel) but prevents overflow in exp for large inputs
        expx = self.be.exp(x - self.be.max(x, axis=0))
        return expx / self.be.sum(expx, axis=0)

    def bprop(self, x):
        """
        We take a shortcut here: the softmax derivative cancels with a
        term in the cross-entropy cost's derivative, so we simply return 1.
        """
        return 1
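The shortcut in bprop deserves a word. With a cross-entropy cost L = -sum(y * log(softmax(x))), the derivative of the cost with respect to the pre-softmax input x simplifies to softmax(x) - y, so the whole expression can be folded into the cost's backward pass and the activation's bprop can return 1. A quick numerical check of that identity, in plain numpy and independent of neon:

In [ ]:
import numpy as np

def np_softmax(z):
    e = np.exp(z - np.max(z))
    return e / np.sum(e)

def cross_entropy(z, y):
    return -np.sum(y * np.log(np_softmax(z)))

x = np.random.randn(5)        # pre-activation scores for one example
y = np.zeros(5)
y[2] = 1.0                    # one-hot target

# analytic gradient of cross_entropy(softmax(x), y) w.r.t. x
analytic = np_softmax(x) - y

# numerical gradient by central differences
eps = 1e-6
numeric = np.zeros_like(x)
for i in range(len(x)):
    xp, xm = x.copy(), x.copy()
    xp[i] += eps
    xm[i] -= eps
    numeric[i] = (cross_entropy(xp, y) - cross_entropy(xm, y)) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-5)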

Test our softmax

1) make some test data on the host
2) move the test data to the device (GPU)
3) calculate softmax on our test data
4) copy the result back to the host and check that it's correct


In [ ]:
# generate some test data using numpy
import numpy as np
data_cpu = np.array([[1,1,1,1],
                     [1,2,3,4],
                     [1,3,5,7]])
print(data_cpu.shape)

# copy test data to the backend (GPU)
data = be.array(data_cpu)

# test our softmax
mysoftmax = MySoftmax()
data[:] = mysoftmax(data)

data_cpu = data.get()
print(data_cpu)

# validate that the features in each column of the output sum to one
col_sums = np.sum(data_cpu, axis=0)
assert np.allclose(col_sums, 1.0, atol=1e-4)
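
With the test passing, MySoftmax can be dropped into a network wherever a built-in transform would go. A rough sketch of an output layer using it (the layer size and initializer are placeholders, not part of the example above):

In [ ]:
from neon.initializers import Gaussian
from neon.layers import Affine

# final layer of a 10-class classifier, with our custom softmax
output_layer = Affine(nout=10, init=Gaussian(scale=0.01),
                      activation=MySoftmax())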